{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras.layers import Dense, Flatten \n",
    "from tensorflow.keras import Model\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# # Download a dataset\n",
    "# (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
    "\n",
    "# # Batch and shuffle the data\n",
    "# train_ds = tf.data.Dataset.from_tensor_slices(\n",
    "#     (x_train.astype('float32') / 255, y_train)).shuffle(1024).batch(32)\n",
    "\n",
    "# test_ds = tf.data.Dataset.from_tensor_slices(\n",
    "#     (x_test.astype('float32') / 255, y_test)).batch(32)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['sky', 'clouds', 'person', 'water', 'animal']\n",
      "df shape (161789, 1)\n",
      "df shape (161789, 1)\n",
      "df shape (161789, 1)\n",
      "df shape (161789, 1)\n",
      "df shape (161789, 1)\n",
      "(60157, 5)\n",
      "data_dir: \n",
      "features_path: NUS_WIDE/NUS_WID_Low_Level_Features/Low_Level_Features\n",
      "b datasets features 64\n",
      "b datasets features 225\n",
      "b datasets features 144\n",
      "b datasets features 73\n",
      "b datasets features 128\n",
      "X image shape: (60157, 634)\n",
      "X text shape: (60157, 1000)\n",
      "<class 'numpy.ndarray'> <class 'numpy.ndarray'> <class 'numpy.ndarray'>\n"
     ]
    }
   ],
   "source": [
    "from nus_wide_data_util import get_labeled_data, get_top_k_labels\n",
    "\n",
    "class_num = 5\n",
    "\n",
    "top_k = get_top_k_labels('', top_k=class_num)\n",
    "print(top_k)\n",
    "\n",
    "data_X_image, data_X_text, data_Y = get_labeled_data('', top_k, 60000)\n",
    "print(type(data_X_image), type(data_X_text), type(data_Y))\n",
    "train_num = 50000\n",
    "\n",
    "x_train, x_test, y_train, y_test = (np.array(data_X_image[:train_num]).astype('float32'), np.array(data_X_text[:train_num]).astype('float32')), \\\n",
    "                                    (np.array(data_X_image[train_num:]).astype('float32'), np.array(data_X_text[train_num:]).astype('float32')), \\\n",
    "                                    np.array(data_Y[:train_num]).astype('float32'), np.array(data_Y[train_num:]).astype('float32')\n",
    "\n",
    "# Batch and shuffle the data\n",
    "train_ds = tf.data.Dataset.from_tensor_slices(\n",
    "    (x_train, y_train)).shuffle(1024).batch(32)\n",
    "\n",
    "test_ds = tf.data.Dataset.from_tensor_slices(\n",
    "    (x_test, y_test)).batch(32)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2156.,  564., 3775., 1093., 2412.], dtype=float32)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.sum(y_test, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VFLPassiveModel(Model):\n",
    "    def __init__(self):\n",
    "        super(VFLPassiveModel, self).__init__()\n",
    "        self.flatten = Flatten()\n",
    "        self.d1 = Dense(class_num, name=\"dense1\")\n",
    "\n",
    "    def call(self, x):\n",
    "        x = self.flatten(x)\n",
    "        return self.d1(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "class VFLActiveModel(Model):\n",
    "    def __init__(self):\n",
    "        super(VFLActiveModel, self).__init__()\n",
    "        self.added = tf.keras.layers.Add()\n",
    "\n",
    "    def call(self, x):\n",
    "        x = self.added(x)\n",
    "        return tf.keras.layers.Softmax()(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Loss: 0.6885123252868652, Accuracy: 75.93600463867188, Test Loss: 0.6690757870674133, Test Accuracy: 76.77000427246094\n",
      "Epoch 2, Loss: 0.4942583739757538, Accuracy: 83.1719970703125, Test Loss: 0.6064096093177795, Test Accuracy: 79.06999969482422\n",
      "Epoch 3, Loss: 0.4461654722690582, Accuracy: 84.6199951171875, Test Loss: 0.5867884755134583, Test Accuracy: 79.66999816894531\n",
      "Epoch 4, Loss: 0.42267748713493347, Accuracy: 85.34400177001953, Test Loss: 0.5739940404891968, Test Accuracy: 80.26000213623047\n",
      "Epoch 5, Loss: 0.4080345034599304, Accuracy: 85.68000030517578, Test Loss: 0.5710135698318481, Test Accuracy: 80.33000183105469\n"
     ]
    }
   ],
   "source": [
    "# normal training\n",
    "\n",
    "passive_model_image = VFLPassiveModel()\n",
    "passive_model_text = VFLPassiveModel()\n",
    "active_model = VFLActiveModel()\n",
    "\n",
    "loss_object = tf.keras.losses.CategoricalCrossentropy()\n",
    "optimizer = tf.keras.optimizers.Adam()\n",
    "\n",
    "train_loss = tf.keras.metrics.Mean(name='train_loss')\n",
    "train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')\n",
    "\n",
    "test_loss = tf.keras.metrics.Mean(name='test_loss')\n",
    "test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')\n",
    "\n",
    "EPOCHS = 5\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "    # For each batch of images and labels\n",
    "    for (images, texts), labels in train_ds:\n",
    "        with tf.GradientTape() as passive_tape:\n",
    "            # passive_model sends passive_output to active_model\n",
    "            passive_image_output = passive_model_image(images)\n",
    "            passive_text_output = passive_model_text(texts)\n",
    "            with tf.GradientTape() as active_tape:\n",
    "                active_tape.watch(passive_image_output)\n",
    "                active_tape.watch(passive_text_output)\n",
    "                active_output = active_model([passive_image_output, passive_text_output])\n",
    "                loss = loss_object(labels, active_output)\n",
    "            # active_model sends passive_output_gradients back to passive_model\n",
    "            [active_image_gradients, active_text_gradients] = active_tape.gradient(loss, [passive_image_output, passive_text_output])\n",
    "            passive_image_loss = tf.multiply(passive_image_output, active_image_gradients.numpy())\n",
    "            passive_text_loss = tf.multiply(passive_text_output, active_text_gradients.numpy())\n",
    "        [passive_image_gradients, passive_text_gradients] = \\\n",
    "        passive_tape.gradient([passive_image_loss, passive_text_loss], \\\n",
    "                              [passive_model_image.trainable_variables, passive_model_text.trainable_variables])\n",
    "        optimizer.apply_gradients(zip(passive_image_gradients, passive_model_image.trainable_variables))\n",
    "        optimizer.apply_gradients(zip(passive_text_gradients, passive_model_text.trainable_variables))\n",
    "        \n",
    "        train_loss(loss)\n",
    "        train_accuracy(labels, active_output)\n",
    "\n",
    "    for (test_images, test_texts), test_labels in test_ds:\n",
    "        passive_output = [passive_model_image(test_images), passive_model_text(test_texts)]\n",
    "        active_output = active_model(passive_output)\n",
    "        t_loss = loss_object(test_labels, active_output)\n",
    "\n",
    "        test_loss(t_loss)\n",
    "        test_accuracy(test_labels, active_output)\n",
    "\n",
    "    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'\n",
    "    print(template.format(epoch+1,\n",
    "                        train_loss.result(),\n",
    "                        train_accuracy.result()*100,\n",
    "                        test_loss.result(),\n",
    "                        test_accuracy.result()*100))\n",
    "\n",
    "    # Reset the metrics for the next epoch\n",
    "    train_loss.reset_states()\n",
    "    train_accuracy.reset_states()\n",
    "    test_loss.reset_states()\n",
    "    test_accuracy.reset_states()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def get_possioned_gradients(passive_output_gradients, passive_output, N, class1, class2, alpha = 1.0):\n",
    "    #passive_output_gradients -= passive_output\n",
    "    attack_mat = np.eye(N, dtype='float32')\n",
    "    attack_mat[:, class2] += attack_mat[:, class1]*alpha\n",
    "    attack_mat[:, class1] -= attack_mat[:, class1]*alpha\n",
    "    passive_output_gradients = tf.matmul(passive_output_gradients, attack_mat)\n",
    "    #passive_output_gradients += passive_output\n",
    "    return passive_output_gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Loss: 3.917205572128296, Accuracy: 60.821998596191406, Test Loss: 4.964493751525879, Test Accuracy: 55.75\n",
      "Epoch 2, Loss: 4.0639872550964355, Accuracy: 64.24400329589844, Test Loss: 4.870058536529541, Test Accuracy: 58.64999771118164\n",
      "Epoch 3, Loss: 4.059752941131592, Accuracy: 65.54399871826172, Test Loss: 4.841984272003174, Test Accuracy: 60.28000259399414\n",
      "Epoch 4, Loss: 4.139057159423828, Accuracy: 66.14400482177734, Test Loss: 4.962497711181641, Test Accuracy: 59.380001068115234\n",
      "Epoch 5, Loss: 4.0564422607421875, Accuracy: 67.2439956665039, Test Loss: 4.755265235900879, Test Accuracy: 62.339996337890625\n"
     ]
    }
   ],
   "source": [
    "# backdoor training\n",
    "\n",
    "passive_model_image = VFLPassiveModel()\n",
    "passive_model_text = VFLPassiveModel()\n",
    "active_model = VFLActiveModel()\n",
    "\n",
    "loss_object = tf.keras.losses.CategoricalCrossentropy()\n",
    "optimizer = tf.keras.optimizers.Adam()\n",
    "optimizer_attack = tf.keras.optimizers.Adam(lr=0.1)\n",
    "\n",
    "train_loss = tf.keras.metrics.Mean(name='train_loss')\n",
    "train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')\n",
    "\n",
    "test_loss = tf.keras.metrics.Mean(name='test_loss')\n",
    "test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')\n",
    "\n",
    "EPOCHS = 5\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "    # For each batch of images and labels\n",
    "    for (images, texts), labels in train_ds:\n",
    "        with tf.GradientTape() as passive_tape:\n",
    "            # passive_model sends passive_output to active_model\n",
    "            passive_image_output = passive_model_image(images)\n",
    "            passive_text_output = passive_model_text(texts)\n",
    "            with tf.GradientTape() as active_tape:\n",
    "                active_tape.watch(passive_image_output)\n",
    "                active_tape.watch(passive_text_output)\n",
    "                active_output = active_model([passive_image_output, passive_text_output])\n",
    "                loss = loss_object(labels, active_output)\n",
    "            # active_model sends passive_output_gradients back to passive_model\n",
    "            [active_image_gradients, active_text_gradients] = active_tape.gradient(loss, [passive_image_output, passive_text_output])\n",
    "            active_image_gradients = get_possioned_gradients(active_image_gradients, passive_image_output, 5, 3, 4)\n",
    "            #active_text_gradients = get_possioned_gradients(active_text_gradients, passive_text_output, 5, 3, 4)\n",
    "            passive_image_loss = tf.multiply(passive_image_output, active_image_gradients.numpy())\n",
    "            passive_text_loss = tf.multiply(passive_text_output, active_text_gradients.numpy())\n",
    "        [passive_image_gradients, passive_text_gradients] = \\\n",
    "        passive_tape.gradient([passive_image_loss, passive_text_loss], \\\n",
    "                              [passive_model_image.trainable_variables, passive_model_text.trainable_variables])\n",
    "        #optimizer.apply_gradients(zip(passive_image_gradients, passive_model_image.trainable_variables))\n",
    "        optimizer_attack.apply_gradients(zip(passive_image_gradients, passive_model_image.trainable_variables))\n",
    "        optimizer.apply_gradients(zip(passive_text_gradients, passive_model_text.trainable_variables))\n",
    "        #optimizer_attack.apply_gradients(zip(passive_text_gradients, passive_model_text.trainable_variables))\n",
    "        \n",
    "        train_loss(loss)\n",
    "        train_accuracy(labels, active_output)\n",
    "\n",
    "    for (test_images, test_texts), test_labels in test_ds:\n",
    "        passive_output = [passive_model_image(test_images), passive_model_text(test_texts)]\n",
    "        active_output = active_model(passive_output)\n",
    "        t_loss = loss_object(test_labels, active_output)\n",
    "\n",
    "        test_loss(t_loss)\n",
    "        test_accuracy(test_labels, active_output)\n",
    "\n",
    "    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'\n",
    "    print(template.format(epoch+1,\n",
    "                        train_loss.result(),\n",
    "                        train_accuracy.result()*100,\n",
    "                        test_loss.result(),\n",
    "                        test_accuracy.result()*100))\n",
    "\n",
    "    # Reset the metrics for the next epoch\n",
    "    train_loss.reset_states()\n",
    "    train_accuracy.reset_states()\n",
    "    test_loss.reset_states()\n",
    "    test_accuracy.reset_states()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 634) (10000, 1000)\n",
      "(564, 634) (564, 1000)\n",
      "[318.85553   95.33841   71.49886   47.012733  31.294613]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAOMklEQVR4nO3cf6jd9X3H8edran8wy7TzKlkSdqXLSu3AWC5W8B+nZU3tWCzUoTAnJSP9Q0FBGLb/tIUJDtY6CpuQLmK6dbVhbTG0si2zFims2qtNU9NUzGymtwnmdlZrKXMkfe+P+w09TU7uPfeec+4xn/t8wOWc8znfc8/7S8gzX74555uqQpLUlt+Y9ACSpNEz7pLUIOMuSQ0y7pLUIOMuSQ06d9IDAFx00UU1PT096TEk6azy1FNP/aSqpvo994aI+/T0NLOzs5MeQ5LOKkn++0zPeVpGkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhpk3CWpQcZdkhr0hviG6jCm7/76pEcYmcP3fnDSI0hqhEfuktQg4y5JDTLuktQg4y5JDTLuktQg4y5JDTLuktQg4y5JDVoy7knekuTJJN9LciDJp7r1S5M8keS5JF9K8qZu/c3d40Pd89Pj3QVJ0qkGOXJ/Hbi2qi4HNgNbklwF/DVwX1VtAn4KbOu23wb8tKp+D7iv206StIqWjHst+Hn38Lzup4BrgX/p1ncBN3T3t3aP6Z6/LklGNrEkaUkDnXNPck6SfcAxYC/wX8ArVXW822QOWN/dXw+8CNA9/yrw231+5/Yks0lm5+fnh9sLSdKvGSjuVXWiqjYDG4ArgXf126y77XeUXqctVO2oqpmqmpmamhp0XknSAJb1aZmqegX4JnAVcEGSk1eV3AAc6e7PARsBuud/C3h5FMNKkgYzyKdlppJc0N1/K/A+4CDwGPDhbrNbgYe7+3u6x3TPf6OqTjtylySNzyDXc18H7EpyDgv/GOyuqq8l+QHwUJK/Ar4L7Oy23wn8Y5JDLByx3zSGuSVJi1gy7lW1H7iiz/rzLJx/P3X9f4EbRzKdJGlF/IaqJDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDVoybgn2ZjksSQHkxxIcke3/skkP06yr/u5vuc1H0tyKMmzSd4/zh2QJJ3u3AG2OQ7cVVVPJ3kb8FSSvd1z91XV3/RunOQy4Cbg3cDvAP+R5Per6sQoB5ckndmSR+5VdbSqnu7uvwYcBNYv8pKtwENV9XpV/Qg4BFw5imElSYNZ1jn3JNPAFcAT3dLtSfYneSDJhd3aeuDFnpfN0ecfgyTbk8wmmZ2fn1/24JKkMxs47knOB74M3FlVPwPuB94BbAaOAp8+uWmfl9dpC1U7qmqmqmampqaWPbgk6cwGinuS81gI+xeq6isAVfVSVZ2oql8Cn+NXp17mgI09L98AHBndyJKkpQzyaZkAO4GDVfWZnvV1PZt9CHimu78HuCnJm5NcCmwCnhzdyJKkpQzyaZmrgVuA7yfZ1619HLg5yWYWTrkcBj4KUFUHkuwGfsDCJ21u85MykrS6lox7VX2L/ufRH1nkNfcA9wwxlyRpCH5DVZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUFLxj3JxiSPJTmY5ECSO7r1tyfZm+S57vbCbj1JPpvkUJL9Sd4z7p2QJP26QY7cjwN3VdW7gKuA25JcBtwNPFpVm4BHu8cAHwA2dT/bgftHPrUkaVFLxr2qjlbV093914CDwHpgK7Cr22wXcEN3fyvw+VrwbeCCJOtGPrkk6YyWdc49yTRwBfAEcElVHYWFfwCAi7vN1gMv9rxsrls79XdtTzKbZHZ+fn75k0uSzmjguCc5H/gycGdV/WyxTfus1WkLVTuqaqaqZqampgYdQ5I0gIHinuQ8FsL+har6Srf80snTLd3tsW59DtjY8/INwJHRjCtJGsQgn5YJsBM4WFWf6XlqD3Brd/9W4OGe9T/vPjVzFfDqydM3kqTVce4A21wN3AJ8P8m+bu3jwL3A7iTbgBeAG7vnHgGuBw4BvwA+MtKJJUlLWjLuVfUt+p9HB7iuz/YF3DbkXJKkIfgNVUlqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lq0JJxT/JAkmNJnulZ+2SSHyfZ1/1c3/Pcx5IcSvJskvePa3BJ0pkNcuT+ILClz/p9VbW5+3kEIMllwE3Au7vX/H2Sc0Y1rCRpMEvGvaoeB14e8PdtBR6qqter6kfAIeDKIeaTJK3AMOfcb0+yvzttc2G3th54sWebuW7tNEm2J5lNMjs/Pz/EGJKkU6007vcD7wA2A0eBT3fr6bNt9fsFVbWjqmaqamZqamqFY0iS+llR3Kvqpao6UVW/BD7Hr069zAEbezbdABwZbkRJ0nKtKO5J1vU8/BBw8pM0e4Cbkrw5yaXAJuDJ4UaUJC3XuUttkOSLwDXARUnmgE8A1yTZzMIpl8PARwGq6kCS3cAPgOPAbVV1YjyjS5LOZMm4V9XNfZZ3LrL9PcA9wwwlSRqO31CVpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lq0JJxT/JAkmNJnulZe3uSvUme624v7NaT5LNJDiXZn+Q94xxektTfIEfuDwJbTlm7G3i0qjYBj3aPAT4AbOp+tgP3j2ZMSdJyLBn3qnocePmU5a3Aru7+LuCGnvXP14JvAxckWTeqYSVJg1npOfdLquooQHd7cbe+HnixZ7u5bu00SbYnmU0yOz8/v8IxJEn9jPo/VNNnrfptWFU7qmqmqmampqZGPIYkrW0rjftLJ0+3dLfHuvU5YGPPdhuAIysfT5K0Eueu8HV7gFuBe7vbh3vWb0/yEPBe4NWTp280HtN3f33SI4zM4Xs/OOkRpGYsGfckXwSuAS5KMgd8goWo706yDXgBuLHb/BHgeuAQ8AvgI2OYWZK0hCXjXlU3n+Gp6/psW8Btww4lSRqO31CVpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lqkHGXpAYZd0lq0Eov+StNnJc7ls7MI3dJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGGXdJapBxl6QGDXX5gSSHgdeAE8DxqppJ8nbgS8A0cBj406r66XBjSpKWYxRH7n9YVZuraqZ7fDfwaFVtAh7tHkuSVtE4TstsBXZ193cBN4zhPSRJixg27gX8e5Knkmzv1i6pqqMA3e3F/V6YZHuS2SSz8/PzQ44hSeo17CV/r66qI0kuBvYm+eGgL6yqHcAOgJmZmRpyDklSj6HiXlVHuttjSb4KXAm8lGRdVR1Nsg44NoI5JZ2ilevZey378VjxaZkkv5nkbSfvA38EPAPsAW7tNrsVeHjYISVJyzPMkfslwFeTnPw9/1xV/5rkO8DuJNuAF4Abhx9TkrQcK457VT0PXN5n/X+A64YZSpI0HL+hKkkNMu6S1CDjLkkNMu6S1CDjLkkNMu6S1KBhLz8gSauulW/nwvi+oeuRuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1yLhLUoOMuyQ1aGxxT7IlybNJDiW5e1zvI0k63VjinuQc4O+ADwCXATcnuWwc7yVJOt24jtyvBA5V1fNV9X/AQ8DWMb2XJOkUqarR/9Lkw8CWqvqL7vEtwHur6vaebbYD27uH7wSeHfkgo3UR8JNJDzEha3nfYW3vv/v+xva7VTXV74lzx/SG6bP2a/+KVNUOYMeY3n/kksxW1cyk55iEtbzvsLb3330/e/d9XKdl5oCNPY83AEfG9F6SpFOMK+7fATYluTTJm4CbgD1jei9J0inGclqmqo4nuR34N+Ac4IGqOjCO91pFZ80ppDFYy/sOa3v/3fez1Fj+Q1WSNFl+Q1WSGmTcJalBxn0Ja/kyCkkeSHIsyTOTnmW1JdmY5LEkB5McSHLHpGdaTUnekuTJJN/r9v9Tk55ptSU5J8l3k3xt0rOshHFfhJdR4EFgy6SHmJDjwF1V9S7gKuC2NfZn/zpwbVVdDmwGtiS5asIzrbY7gIOTHmKljPvi1vRlFKrqceDlSc8xCVV1tKqe7u6/xsJf8vWTnWr11IKfdw/P637WzKcvkmwAPgj8w6RnWSnjvrj1wIs9j+dYQ3/BtSDJNHAF8MRkJ1ld3WmJfcAxYG9VraX9/1vgL4FfTnqQlTLui1vyMgpqW5LzgS8Dd1bVzyY9z2qqqhNVtZmFb5hfmeQPJj3Takjyx8Cxqnpq0rMMw7gvzssorGFJzmMh7F+oqq9Mep5JqapXgG+ydv7/5WrgT5IcZuFU7LVJ/mmyIy2fcV+cl1FYo5IE2AkcrKrPTHqe1ZZkKskF3f23Au8DfjjZqVZHVX2sqjZU1TQLf+e/UVV/NuGxls24L6KqjgMnL6NwENjdwGUUBpbki8B/Au9MMpdk26RnWkVXA7ewcNS2r/u5ftJDraJ1wGNJ9rNwkLO3qs7KjwSuVV5+QJIa5JG7JDXIuEtSg4y7JDXIuEtSg4y7JDXIuEtSg4y7JDXo/wHFJ4ycL1/0ewAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "label_test = np.argmax(y_test, axis=1)\n",
    "(image_test, text_test) = x_test\n",
    "print(image_test.shape, text_test.shape)\n",
    "image_val = image_test[label_test==1]\n",
    "text_val = text_test[label_test==1]\n",
    "print(image_val.shape, text_val.shape)\n",
    "y_val = y_test[label_test==1]\n",
    "\n",
    "passive_output = [passive_model_image(image_val), passive_model_text(text_val)]\n",
    "active_output = active_model(passive_output)\n",
    "\n",
    "output_distribution = np.sum(active_output, axis=0)\n",
    "print(output_distribution)\n",
    "\n",
    "n = 5\n",
    "X = np.arange(n)\n",
    "plt.bar(X, output_distribution)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyLinearModel(Model):\n",
    "  def __init__(self):\n",
    "    super(MyLinearModel, self).__init__()\n",
    "    self.flatten = Flatten()\n",
    "    self.d1 = Dense(class_num, activation='softmax', name=\"dense1\")\n",
    "\n",
    "  def call(self, x):\n",
    "    x = self.flatten(x)\n",
    "    return self.d1(x)\n",
    "\n",
    "model = MyLinearModel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_object = tf.keras.losses.SparseCategoricalCrossentropy()\n",
    "optimizer = tf.keras.optimizers.Adam()\n",
    "\n",
    "train_loss = tf.keras.metrics.Mean(name='train_loss')\n",
    "train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')\n",
    "\n",
    "test_loss = tf.keras.metrics.Mean(name='test_loss')\n",
    "test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')\n",
    "\n",
    "EPOCHS = 5\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "  # For each batch of images and labels\n",
    "  for images, labels in train_ds:\n",
    "    with tf.GradientTape() as tape:\n",
    "      predictions = model(images)\n",
    "      loss = loss_object(labels, predictions)\n",
    "    gradients = tape.gradient(loss, model.trainable_variables)\n",
    "    print(gradients[0].shape, gradients[1].shape)\n",
    "    optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "\n",
    "    train_loss(loss)\n",
    "    train_accuracy(labels, predictions)\n",
    "\n",
    "  for test_images, test_labels in test_ds:\n",
    "    predictions = model(test_images)\n",
    "    t_loss = loss_object(test_labels, predictions)\n",
    "\n",
    "    test_loss(t_loss)\n",
    "    test_accuracy(test_labels, predictions)\n",
    "\n",
    "  template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'\n",
    "  print(template.format(epoch+1,\n",
    "                        train_loss.result(),\n",
    "                        train_accuracy.result()*100,\n",
    "                        test_loss.result(),\n",
    "                        test_accuracy.result()*100))\n",
    "\n",
    "  # Reset the metrics for the next epoch\n",
    "  train_loss.reset_states()\n",
    "  train_accuracy.reset_states()\n",
    "  test_loss.reset_states()\n",
    "  test_accuracy.reset_states()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.GradientTape() as tape:\n",
    "    a = tf.constant(2.)\n",
    "    b = tf.constant(1.)\n",
    "    tape.watch(a)\n",
    "    tape.watch(b)\n",
    "    c = tf.multiply(a, b)\n",
    "g = tape.gradient(c, [a, b])\n",
    "print(g)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
